{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing your own lifting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we show how you can implement your own lifting and test it on a dataset. \n",
    "\n",
    "This particular example uses the MUTAG dataset. The lifting for this example is similar to the SimplicialCliqueLifting but discards the cliques that are bigger than the maximum simplices we want to consider.\n",
    "\n",
    "We test this lifting using the SCN2 model from `TopoModelX`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>\n",
    "&emsp;[1. Imports](##sec1)\n",
    "\n",
    "&emsp;[2. Configurations and utilities](##sec2)\n",
    "\n",
    "&emsp;[3. Defining the lifting](##sec2)\n",
    "\n",
    "&emsp;[4. Loading the data](##sec3)\n",
    "\n",
    "&emsp;[5. Model initialization](##sec4)\n",
    "\n",
    "&emsp;[6. Training](##sec5)\n",
    "\n",
    "&emsp;[7. Testing the model](##sec6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "from typing import Any\n",
    "\n",
    "import lightning as pl\n",
    "import torch_geometric\n",
    "import networkx as nx\n",
    "\n",
    "# Hydra related imports\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Date related imports\n",
    "from toponetx.classes import SimplicialComplex\n",
    "from topobench.data.loaders.graph import TUDatasetLoader\n",
    "from topobench.data.preprocessor import PreProcessor\n",
    "from topobench.dataloader import TBDataloader\n",
    "from topobench.transforms.liftings.graph2simplicial import (\n",
    "    Graph2SimplicialLifting,\n",
    ")\n",
    "\n",
    "# Model related imports\n",
    "from topobench.model.model import TBModel\n",
    "from topomodelx.nn.simplicial.scn2 import SCN2\n",
    "from topobench.nn.wrappers.simplicial import SCNWrapper\n",
    "from topobench.nn.encoders import AllCellFeatureEncoder\n",
    "from topobench.nn.readouts import PropagateSignalDown\n",
    "\n",
    "# Optimization related imports\n",
    "from topobench.loss.loss import TBLoss\n",
    "from topobench.optimizer import TBOptimizer\n",
    "from topobench.evaluator.evaluator import TBEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configurations can be specified using yaml files or directly specified in your code like in this example. To keep the notebook clean here we already define the configuration for the lifting, which is defined later in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader_config = {\n",
    "    \"data_domain\": \"graph\",\n",
    "    \"data_type\": \"TUDataset\",\n",
    "    \"data_name\": \"MUTAG\",\n",
    "    \"data_dir\": \"./data/MUTAG/\"}\n",
    "\n",
    "\n",
    "transform_config = { \"clique_lifting\":\n",
    "    {\"_target_\": \"__main__.SimplicialCliquesLEQLifting\",\n",
    "     \"transform_name\": \"SimplicialCliquesLEQLifting\",\n",
    "    \"transform_type\": \"lifting\",\n",
    "    \"complex_dim\": 3,}\n",
    "}\n",
    "\n",
    "split_config = {\n",
    "    \"learning_setting\": \"inductive\",\n",
    "    \"split_type\": \"k-fold\",\n",
    "    \"data_seed\": 0,\n",
    "    \"data_split_dir\": \"./data/MUTAG/splits/\",\n",
    "    \"k\": 10,\n",
    "}\n",
    "\n",
    "in_channels = 7\n",
    "out_channels = 2\n",
    "dim_hidden = 128\n",
    "\n",
    "wrapper_config = {\n",
    "    \"out_channels\": dim_hidden,\n",
    "    \"num_cell_dimensions\": 3,\n",
    "}\n",
    "\n",
    "readout_config = {\n",
    "    \"readout_name\": \"PropagateSignalDown\",\n",
    "    \"num_cell_dimensions\": 1,\n",
    "    \"hidden_dim\": dim_hidden,\n",
    "    \"out_channels\": out_channels,\n",
    "    \"task_level\": \"graph\",\n",
    "    \"pooling_type\": \"sum\",\n",
    "}\n",
    "\n",
    "loss_config = {\n",
    "    \"dataset_loss\": \n",
    "        {\n",
    "            \"task\": \"classification\", \n",
    "            \"loss_type\": \"cross_entropy\"\n",
    "        }\n",
    "}\n",
    "\n",
    "evaluator_config = {\"task\": \"classification\",\n",
    "                    \"num_classes\": out_channels,\n",
    "                    \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n",
    "\n",
    "optimizer_config = {\"optimizer_id\": \"Adam\",\n",
    "                    \"parameters\":\n",
    "                        {\"lr\": 0.001,\"weight_decay\": 0.0005}\n",
    "                    }\n",
    "\n",
    "\n",
    "loader_config = OmegaConf.create(loader_config)\n",
    "transform_config = OmegaConf.create(transform_config)\n",
    "split_config = OmegaConf.create(split_config)\n",
    "wrapper_config = OmegaConf.create(wrapper_config)\n",
    "readout_config = OmegaConf.create(readout_config)\n",
    "loss_config = OmegaConf.create(loss_config)\n",
    "evaluator_config = OmegaConf.create(evaluator_config)\n",
    "optimizer_config = OmegaConf.create(optimizer_config)"
   ]
  },
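  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the YAML route mentioned above, the snippet below builds the same loader configuration from YAML text instead of a Python dictionary. `OmegaConf.create` accepts a YAML string as well as a dictionary; the names `yaml_text`, `loader_config_from_yaml` and `as_dict` are made up for this illustration, and in practice the YAML would usually live in its own file.\n",
    "\n",
    "```python\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# YAML text that mirrors the loader_config dictionary defined above.\n",
    "yaml_text = '''\n",
    "data_domain: graph\n",
    "data_type: TUDataset\n",
    "data_name: MUTAG\n",
    "data_dir: ./data/MUTAG/\n",
    "'''\n",
    "\n",
    "# Parse the YAML text into an OmegaConf config object.\n",
    "loader_config_from_yaml = OmegaConf.create(yaml_text)\n",
    "\n",
    "# Convert back to a plain dictionary to compare with the in-code version.\n",
    "as_dict = OmegaConf.to_container(loader_config_from_yaml)\n",
    "print(as_dict['data_name'])\n",
    "```\n",
    "\n",
    "Both routes should give an equivalent configuration object, so the rest of the notebook does not depend on which one is used."
   ]
  },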
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrapper(**factory_kwargs):\n",
    "    def factory(backbone):\n",
    "        return SCNWrapper(backbone, **factory_kwargs)\n",
    "    return factory"
   ]
  },
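  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `wrapper` helper above is a closure: it stores the keyword arguments now and builds the wrapper later, once a backbone is handed to it. The sketch below shows the same pattern with stand-in names. `DummyWrapper` and `make_wrapper_factory` are hypothetical and only illustrate the idea; they are not part of TopoBench.\n",
    "\n",
    "```python\n",
    "class DummyWrapper:\n",
    "    # Hypothetical stand-in for SCNWrapper: stores the backbone and its options.\n",
    "    def __init__(self, backbone, **kwargs):\n",
    "        self.backbone = backbone\n",
    "        self.kwargs = kwargs\n",
    "\n",
    "\n",
    "def make_wrapper_factory(**factory_kwargs):\n",
    "    # Capture the options now; build the wrapper once a backbone is supplied.\n",
    "    def factory(backbone):\n",
    "        return DummyWrapper(backbone, **factory_kwargs)\n",
    "    return factory\n",
    "\n",
    "\n",
    "factory = make_wrapper_factory(out_channels=128, num_cell_dimensions=3)\n",
    "wrapped = factory('my-backbone')\n",
    "print(wrapped.backbone, wrapped.kwargs)\n",
    "```\n",
    "\n",
    "Deferring construction this way lets the options be fixed in one place while the backbone is supplied elsewhere."
   ]
  },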
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Defining the lifting  <a class=\"anchor\" id=\"sec3\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we define the lifting we intend on using. The `SimplicialCliquesLEQLifting` finds the cliques that have a number of nodes less or equal to the maximum simplices we want to consider and creates simplices from them. The configuration for the lifting was already defined with the other configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimplicialCliquesLEQLifting(Graph2SimplicialLifting):\n",
    "    r\"\"\"Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n",
    "    \n",
    "    Args:\n",
    "        kwargs (optional): Additional arguments for the class.\n",
    "    \"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    def lift_topology(self, data: torch_geometric.data.Data) -> dict:\n",
    "        r\"\"\"Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n",
    "\n",
    "        Args:\n",
    "            data (torch_geometric.data.Data): The input data to be lifted.\n",
    "        Returns:\n",
    "            dict: The lifted topology.\n",
    "        \"\"\"\n",
    "        graph = self._generate_graph_from_data(data)\n",
    "        simplicial_complex = SimplicialComplex(graph)\n",
    "        cliques = nx.find_cliques(graph)\n",
    "        \n",
    "        simplices: list[set[tuple[Any, ...]]] = [set() for _ in range(2, self.complex_dim + 1)]\n",
    "        for clique in cliques:\n",
    "            if len(clique) <= self.complex_dim + 1:\n",
    "                for i in range(2, self.complex_dim + 1):\n",
    "                    for c in combinations(clique, i + 1):\n",
    "                        simplices[i - 2].add(tuple(c))\n",
    "\n",
    "        for set_k_simplices in simplices:\n",
    "            simplicial_complex.add_simplices_from(list(set_k_simplices))\n",
    "\n",
    "        return self._get_lifted_topology(simplicial_complex, graph)\n"
   ]
  },
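  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the effect of the size filter without loading any data, the sketch below mirrors the loop of the lifting above on a hand-written list of cliques. The function name `simplices_from_cliques` and the toy cliques are assumptions made for this illustration, and the nodes are sorted only to make the result deterministic.\n",
    "\n",
    "```python\n",
    "from itertools import combinations\n",
    "\n",
    "\n",
    "def simplices_from_cliques(cliques, complex_dim):\n",
    "    # One set per simplex dimension 2, ..., complex_dim.\n",
    "    simplices = [set() for _ in range(2, complex_dim + 1)]\n",
    "    for clique in cliques:\n",
    "        # A clique with more than complex_dim + 1 nodes is skipped entirely.\n",
    "        if len(clique) <= complex_dim + 1:\n",
    "            for dim in range(2, complex_dim + 1):\n",
    "                for nodes in combinations(sorted(clique), dim + 1):\n",
    "                    simplices[dim - 2].add(nodes)\n",
    "    return simplices\n",
    "\n",
    "\n",
    "# Hypothetical maximal cliques: a triangle, a 4-clique and a single edge.\n",
    "toy_cliques = [[0, 1, 2], [2, 3, 4, 5], [5, 6]]\n",
    "\n",
    "dim2_only = simplices_from_cliques(toy_cliques, complex_dim=2)\n",
    "up_to_dim3 = simplices_from_cliques(toy_cliques, complex_dim=3)\n",
    "\n",
    "print(dim2_only)\n",
    "print(len(up_to_dim3[0]), len(up_to_dim3[1]))\n",
    "```\n",
    "\n",
    "With `complex_dim=2` the 4-clique is dropped, so only the triangle `(0, 1, 2)` is produced. With `complex_dim=3` the 4-clique passes the filter and contributes four triangles and one tetrahedron."
   ]
  },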
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Loading the data  <a class=\"anchor\" id=\"sec4\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example we use the MUTAG dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from topobench.transforms import TRANSFORMS\n",
    "\n",
    "TRANSFORMS[\"SimplicialCliquesLEQLifting\"] = SimplicialCliquesLEQLifting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform parameters are the same, using existing data_dir: data/MUTAG/MUTAG/clique_lifting/458544608\n"
     ]
    }
   ],
   "source": [
    "graph_loader = TUDatasetLoader(loader_config)\n",
    "\n",
    "dataset, dataset_dir = graph_loader.load()\n",
    "\n",
    "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n",
    "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n",
    "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Model initialization  <a class=\"anchor\" id=\"sec5\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can create the backbone by instantiating the SCN2 model form TopoModelX. Then the `SCNWrapper` and the `TBModel` take care of the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "backbone = SCN2(in_channels_0=dim_hidden,in_channels_1=dim_hidden,in_channels_2=dim_hidden)\n",
    "backbone_wrapper = wrapper(**wrapper_config)\n",
    "\n",
    "readout = PropagateSignalDown(**readout_config)\n",
    "loss = TBLoss(**loss_config)\n",
    "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n",
    "\n",
    "evaluator = TBEvaluator(**evaluator_config)\n",
    "optimizer = TBOptimizer(**optimizer_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TBModel(backbone=backbone,\n",
    "                 backbone_wrapper=backbone_wrapper,\n",
    "                 readout=readout,\n",
    "                 loss=loss,\n",
    "                 feature_encoder=feature_encoder,\n",
    "                 evaluator=evaluator,\n",
    "                 optimizer=optimizer,\n",
    "                 compile=False,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Training  <a class=\"anchor\" id=\"sec6\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the `lightning` trainer to train the model. We are prompted to connet a Wandb account to monitor training, but we can also obtain the final training metrics from the trainer directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n",
      "\n",
      "  | Name            | Type                  | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | feature_encoder | AllCellFeatureEncoder | 53.8 K | train\n",
      "1 | backbone        | SCNWrapper            | 99.1 K | train\n",
      "2 | readout         | PropagateSignalDown   | 258    | train\n",
      "3 | val_acc_best    | MeanMetric            | 0      | train\n",
      "------------------------------------------------------------------\n",
      "153 K     Trainable params\n",
      "0         Non-trainable params\n",
      "153 K     Total params\n",
      "0.612     Total estimated model params size (MB)\n",
      "36        Modules in train mode\n",
      "0         Modules in eval mode\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/projects/dev/TopoBench/topobench/nn/wrappers/simplicial/scn_wrapper.py:75: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)\n",
      "  normalized_matrix = diag_matrix @ (matrix @ diag_matrix)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=50` reached.\n"
     ]
    }
   ],
   "source": [
    "# Increase the number of epochs to get better results\n",
    "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n",
    "\n",
    "trainer.fit(model, datamodule)\n",
    "train_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Training metrics\n",
      " --------------------------\n",
      "train/accuracy:       0.7574\n",
      "train/precision:      0.7289\n",
      "train/recall:         0.7308\n",
      "val/loss:             0.6770\n",
      "val/accuracy:         0.7895\n",
      "val/precision:        0.7750\n",
      "val/recall:           0.7115\n",
      "train/loss:           0.5597\n"
     ]
    }
   ],
   "source": [
    "print('      Training metrics\\n', '-'*26)\n",
    "for key in train_metrics:\n",
    "    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Testing the model  <a class=\"anchor\" id=\"sec7\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can test the model and obtain the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test/accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7894737124443054     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6769953966140747     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test/precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7749999761581421     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test/recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7115384340286255     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test/accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7894737124443054    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6769953966140747    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test/precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7749999761581421    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test/recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7115384340286255    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.test(model, datamodule)\n",
    "test_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Testing metrics\n",
      " -------------------------\n",
      "test/loss:           0.6770\n",
      "test/accuracy:       0.7895\n",
      "test/precision:      0.7750\n",
      "test/recall:         0.7115\n"
     ]
    }
   ],
   "source": [
    "print('      Testing metrics\\n', '-'*25)\n",
    "for key in test_metrics:\n",
    "    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topobench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
